{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5ddbf0ca-4461-4666-95cd-e10ac9b3afc3",
   "metadata": {},
   "source": [
    "This notebook provides a short example of how to train an S2S benchmark model using the PyTorch Lightning (`lightning.pytorch`) module. The config loaded below is `segformer_s2s.yaml`; point `model_config_filepath` at a different config file to train another model (e.g. an MLP).\n",
    "\n",
    "The complete training script, `train.py`, can be found in the root directory of the repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0f782294-fb06-4789-a15f-fd547441a587",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "25531188-c0b4-404c-997a-43d533ea1cf0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[rank: 0] Global seed set to 42\n",
      "/burg/home/jn2808/.conda/envs/bench/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import yaml\n",
    "import torch\n",
    "import lightning.pytorch as pl\n",
    "from lightning.pytorch import loggers as pl_loggers\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint\n",
    "pl.seed_everything(42)\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from chaosbench import dataset, config, utils, criterion\n",
    "from chaosbench.models import model, mlp, cnn, ae, fno, vit\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56a60726-4773-4f49-a45a-a435dd471c87",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the YAML config file that holds the definitions needed to fit/eval a model\n",
    "\n",
    "model_config_filepath = '../chaosbench/configs/segformer_s2s.yaml'\n",
    "\n",
    "# Explicit encoding so the read does not depend on the platform's locale.\n",
    "# NOTE(review): if this config only uses standard YAML tags, `yaml.safe_load` would be\n",
    "# a stricter drop-in for `yaml.load(..., Loader=yaml.FullLoader)` -- confirm before switching.\n",
    "with open(model_config_filepath, 'r', encoding='utf-8') as config_file:\n",
    "    hyperparams = yaml.load(config_file, Loader=yaml.FullLoader)\n",
    "\n",
    "# Pull out the two top-level sections used by the cells below\n",
    "model_args = hyperparams['model_args']\n",
    "data_args = hyperparams['data_args']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "504d6fbb-aae3-4e41-9885-d5a7f543f519",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# The hyperparameters are organized into two top-level sections:\n",
    "# `model_args` for model specification\n",
    "# `data_args` for data definition\n",
    "\n",
    "hyperparams\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97bed92a-a3e2-458f-abc3-8e76843b290a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Instantiate the benchmark model from the model and dataset hyperparameters\n",
    "\n",
    "baseline = model.S2SBenchmarkModel(\n",
    "    model_args=model_args,\n",
    "    data_args=data_args,\n",
    ")\n",
    "\n",
    "# Call setup() explicitly on the model before it is handed to the trainer\n",
    "baseline.setup()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23988224-486f-4419-85f1-0b2cdc2abe54",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Setup trainer\n",
    "# Including tensorboard logger and checkpoint callback (e.g. saving top-1 based on lowest validation error)\n",
    "\n",
    "tb_logger = pl_loggers.TensorBoardLogger(save_dir=f'logs/{model_args[\"model_name\"]}')\n",
    "checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')\n",
    "\n",
    "# 'auto' lets Lightning choose the accelerator and device count from the available hardware,\n",
    "# so the notebook also runs on machines without a GPU (all GPUs are still used when present)\n",
    "trainer = pl.Trainer(\n",
    "    accelerator='auto',\n",
    "    devices='auto',\n",
    "    strategy='auto',\n",
    "    max_epochs=model_args['epochs'],\n",
    "    logger=tb_logger,\n",
    "    callbacks=[checkpoint_callback],\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68e7a603-78b8-44c4-b043-b3698736ea82",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run training on the model configured above\n",
    "# Checkpoints can be found under `logs/<MODEL_NAME>`\n",
    "\n",
    "trainer.fit(model=baseline)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f23d4d72-8d52-4bd7-ae69-7119bede12f1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bench",
   "language": "python",
   "name": "bench"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
